library(bridgesampling)
library(EnvStats)
library(R2jags)
library(Hmisc)
library(rms)
library(readr)
library(tidyr)
library(dplyr)
library(magrittr)
library(knitr)
library(MCMCvis)
library(BayesianTools)
library(latex2exp)
library(tableone)
# Attach order is left untouched because it decides which package masks which.
# ggplot2 is attached explicitly: the plotting code calls ggplot() and friends
# directly and must not rely on another package attaching it as a side effect.
library(ggplot2)

#' Compile a JAGS model from a model string, burn it in, and draw samples.
#'
#' @param model Character string holding the model definition in JAGS syntax.
#' @param dat Named list of data handed to `jags.model()`.
#' @param chains Number of chains (`n.chains`).
#' @param burnin Number of iterations run through `update()` before any
#'   monitoring starts; these draws are not returned.
#' @param iterations Number of iterations passed to `coda.samples()` as
#'   `n.iter`.
#' @param thinning Thinning interval passed to `coda.samples()` as `thin`.
#' @param varnames Names of the nodes to monitor.
#' @return The object returned by `coda.samples()`.
fit_mod <- function(model, dat, chains = 4, 
                    burnin = 5000, iterations = 10000, thinning = 1,
                    varnames = c("b", "tau_u", "u")) {
  # The text connection used to be created inline and was never closed, so
  # every call left one open connection behind. Own it here and close it on
  # exit, including when model compilation fails.
  model_con <- textConnection(model)
  on.exit(close(model_con), add = TRUE)

  mod <- jags.model(file = model_con, data = dat, n.chains = chains)

  # Burn-in: advance the chains without recording anything
  update(mod, burnin)

  coda.samples(mod, variable.names = varnames,
               n.iter = iterations, thin = thinning)
}



## Data preparation

# Load the trial-level dataset
dat3 <- read_csv("Sigworth_iScience_Complete_Data.csv")

# Row number in the raw file; the Table 1 code joins cleaned rows back to dat3
# on this key. seq_len() stays correct (empty) for a zero-row file.
dat3$trial_num <- seq_len(nrow(dat3))

# Keep the analysis variables and derive, per row:
#   n_obs - event count implied by the reported percentage and the arm size
#   sei   - binomial standard error of the reported proportion
#   id    - running row identifier
dat2 <- dat3 %>%
  dplyr::select(pmid, pmidurl, taxane, all_grade_neuropathy,
                delivered_taxane, age, dose_gap, number_taxane_arm,
                trial_num) %>%
  mutate(n_obs = round(number_taxane_arm * all_grade_neuropathy / 100, 0),
         sei = sqrt((all_grade_neuropathy / 100) *
                      (1 - (all_grade_neuropathy / 100)) / number_taxane_arm),
         id = seq_len(n()))

# Keep rows that are complete in the analysis variables (named explicitly
# rather than by column position, so reordering the select() above cannot
# silently change which columns are checked), with a positive delivered dose
# and a positive reported neuropathy percentage.
dat <- dat2 %>%
  filter(complete.cases(taxane, all_grade_neuropathy, delivered_taxane, age,
                        dose_gap, number_taxane_arm, trial_num)) %>%
  filter(delivered_taxane > 0, all_grade_neuropathy > 0) %>%
  mutate(id3 = seq_len(n()))

# Make Paclitaxel the reference level of the drug factor
dat$taxane <- relevel(factor(dat$taxane), ref = "Paclitaxel")

# Log and square-root versions of the delivered dose
dat$log_tax <- log(dat$delivered_taxane)
dat$sqrt_tax <- sqrt(dat$delivered_taxane)

# Box-Cox version of the delivered dose, using the lambda chosen by
# EnvStats::boxcox(optimize = TRUE). The namespace is spelled out so another
# attached package exporting a `boxcox` cannot be picked up by accident.
dt <- dat$delivered_taxane
trans <- EnvStats::boxcox(dt, optimize = TRUE)
dat$boxcox_tax <- EnvStats::boxcoxTransform(dt, lambda = trans$lambda)
# Standardize every dose transform separately within each drug

## Split the data by drug
dat_doc <- subset(dat, taxane == "Docetaxel")
dat_pac <- subset(dat, taxane == "Paclitaxel")

## Center and scale each version of the dose. The scale() results are kept as
## objects because they carry the "scaled:center" / "scaled:scale" attributes
## (those of doc_norm and pac_norm are read back later in the script).
doc_norm  <- scale(dat_doc$delivered_taxane)  # linear
doc_norm2 <- scale(dat_doc$log_tax)           # log
doc_norm3 <- scale(dat_doc$sqrt_tax)          # square root
doc_norm4 <- scale(dat_doc$boxcox_tax)        # Box-Cox

pac_norm  <- scale(dat_pac$delivered_taxane)  # linear
pac_norm2 <- scale(dat_pac$log_tax)           # log
pac_norm3 <- scale(dat_pac$sqrt_tax)          # square root
pac_norm4 <- scale(dat_pac$boxcox_tax)        # Box-Cox

## Attach the standardized values as plain numeric columns
## (as.vector() drops the one-column matrix shape and the scaling attributes)
dat_doc <- dat_doc %>%
  mutate(deliv_norm = as.vector(doc_norm),
         deliv_log_norm = as.vector(doc_norm2),
         deliv_sqrt_norm = as.vector(doc_norm3),
         deliv_boxcox_norm = as.vector(doc_norm4))

dat_pac <- dat_pac %>%
  mutate(deliv_norm = as.vector(pac_norm),
         deliv_log_norm = as.vector(pac_norm2),
         deliv_sqrt_norm = as.vector(pac_norm3),
         deliv_boxcox_norm = as.vector(pac_norm4))

## Stack the two drugs back into one dataset (docetaxel rows first)
dat <- rbind(dat_doc, dat_pac)


# Standardize age across all rows and express the outcome as a proportion
dat$age_norm <- c(scale(dat$age))
dat$outcome <- dat$all_grade_neuropathy / 100

# Drop pmid, keep only rows complete in every remaining column, and add the
# final row identifier id2 (used as the random-effect index in the models).
# seq_len(n()) also works when no rows are left, where 1:n() would error.
dat <- dat %>%
  dplyr::select(-pmid) %>%
  filter(complete.cases(.)) %>%
  mutate(id2 = seq_len(n()))

# Logit of the outcome and the variance of that logit
dat %<>% mutate(logit_outcome = log(outcome / (1 - outcome)),
                logit_var = (1 / n_obs) + (1 / (number_taxane_arm - n_obs)))

# The earlier filter only requires a positive reported percentage. A row
# whose percentage is 100, or whose rounded event count n_obs is 0 or equals
# the arm size, yields an infinite logit or variance. Fail here with a clear
# message instead of passing such values on to the models.
if (any(!is.finite(dat$logit_outcome) | !is.finite(dat$logit_var))) {
  stop("Non-finite logit_outcome/logit_var: some rows have n_obs equal to 0 ",
       "or to number_taxane_arm (or an outcome of 100%).",
       call. = FALSE)
}

# Drug indicators and drug-specific dose terms: each *_dose/_log/_sqrt/_box
# column holds the standardized dose for rows of that drug and 0 otherwise
dat <- dat %>%
  mutate(pac = taxane == "Paclitaxel",
         doc = taxane == "Docetaxel",
         pac_dose = pac * deliv_norm,
         doc_dose = doc * deliv_norm,
         pac_log = pac * deliv_log_norm,
         doc_log = doc * deliv_log_norm,
         pac_sqrt = pac * deliv_sqrt_norm,
         doc_sqrt = doc * deliv_sqrt_norm,
         pac_box = pac * deliv_boxcox_norm,
         doc_box = doc * deliv_boxcox_norm)

# Build the Table 1 dataset: descriptive columns from the raw file, restricted
# to the rows that survived cleaning and extended with n_obs and outcome.
# The join key is stated explicitly; trial_num is the only column the two
# inputs share, so this matches the previous implicit natural join while
# guarding against a silent key change if either select() is edited.
tab_dat <- dat3 %>% 
  dplyr::select(age, Num_Male, taxane, trial_num, year, prior_taxane,
                prior_platinum, prior_chemo, disease,
                phase, number_taxane_arm, median_followup_duration_mos,
                dose_cycle, dosing_freq, unit_dose, cycle_length_mos,
                median_cycles, median_delivered_intensity,
                delivered_taxane, dose_gap, all_grade_neuropathy) %>%
  inner_join(dplyr::select(dat, n_obs, outcome, trial_num),
             by = "trial_num") %>%
  mutate(perc_male = Num_Male / number_taxane_arm * 100) 

# Variables passed to print() as `nonnormal`, and variables treated as factors
nnn <- c("year", "n_obs", "unit_dose", "median_cycles",
         "median_followup_duration_mos",
         "number_taxane_arm", "delivered_taxane", "all_grade_neuropathy", 
         "age", "perc_male", "cycle_length_mos")
fctvars <- c("phase", "dose_gap")

# Generate Table 1 stratified by drug and export it as CSV
tab3 <- CreateTableOne(strata = "taxane", data = tab_dat, factorVars = fctvars)
tab4 <- print(tab3, noSpaces = TRUE, quote = FALSE, printToggle = FALSE, 
              nonnormal = nnn, showAllLevels = TRUE)
write.csv(tab4, file = "taxane_table_one2.csv")


## Model on the logit outcome with the linear (untransformed) dose

# JAGS definition. Each row's logit_outcome is normal around mu with precision
# 1/logit_var. mu is an intercept, a docetaxel indicator, one dose slope per
# drug (pac_dose, doc_dose), age_norm, dose_gap and a random intercept u
# indexed by id2. u has precision tau_u = sigma^-2; the six b coefficients get
# dnorm(0, 1E-6) priors and sigma a dgamma(0.001, 0.001) prior.
lin_model <- "
model{

#Likelihood
  for( i in 1:n)
    {
      logit_outcome[i]~dnorm(mu[i], 1/(logit_var[i]))
      mu[i]<-b[1]+b[2]*doc[i]+b[3]*pac_dose[i]+b[4]*doc_dose[i]+
        b[5]*age_norm[i]+b[6]*dose_gap[i]+u[id2[i]]
    }

for (j in 1:nid)
  {
    u[j]~dnorm(0, tau_u)
  }
#priors 
for(j in 1:6) { b[j]~dnorm(0, 1E-6)}

tau_u <- pow(sigma, -2)
      sigma ~ dgamma(0.001, 0.001)


}
"

# Data list for the model; n and nid are both the number of distinct id2 values
lin_dat <- with(dat, list(
  logit_outcome = logit_outcome,
  logit_var = logit_var,
  doc = doc,
  pac_dose = pac_dose,
  doc_dose = doc_dose,
  age_norm = age_norm,
  dose_gap = dose_gap,
  id2 = id2,
  n = length(unique(id2)),
  nid = length(unique(id2))
))


# Same structure as the linear-dose model, but the drug-specific dose terms
# are the standardized log doses (pac_log, doc_log)
log_model <- "
model{

#Likelihood
  for( i in 1:n)
    {
      logit_outcome[i]~dnorm(mu[i], 1/(logit_var[i]))
      mu[i]<-b[1]+b[2]*doc[i]+b[3]*pac_log[i]+b[4]*doc_log[i]+
        b[5]*age_norm[i]+b[6]*dose_gap[i]+u[id2[i]]
    }

for (j in 1:nid)
  {
    u[j]~dnorm(0, tau_u)
  }
#priors
for(j in 1:6) { b[j]~dnorm(0, 1E-6)}

tau_u <- pow(sigma, -2)
      sigma ~ dgamma(0.001, 0.001)


}
"

# Data list for the model; n and nid are both the number of distinct id2 values
log_dat <- with(dat, list(
  logit_outcome = logit_outcome,
  logit_var = logit_var,
  doc = doc,
  pac_log = pac_log,
  doc_log = doc_log,
  age_norm = age_norm,
  dose_gap = dose_gap,
  id2 = id2,
  n = length(unique(id2)),
  nid = length(unique(id2))
))


# Same structure as the linear-dose model, but the drug-specific dose terms
# are the standardized square-root doses (pac_sqrt, doc_sqrt)
sqrt_model <- "
model{

#Likelihood
  for( i in 1:n)
    {
      logit_outcome[i]~dnorm(mu[i], 1/(logit_var[i]))
      mu[i]<-b[1]+b[2]*doc[i]+b[3]*pac_sqrt[i]+b[4]*doc_sqrt[i]+
        b[5]*age_norm[i]+b[6]*dose_gap[i]+u[id2[i]]
    }

for (j in 1:nid)
  {
    u[j]~dnorm(0, tau_u)
  }
#priors
for(j in 1:6) { b[j]~dnorm(0, 1E-6)}

tau_u <- pow(sigma, -2)
      sigma ~ dgamma(0.001, 0.001)


}
"

# Data list for the model; n and nid are both the number of distinct id2 values
sqrt_dat <- with(dat, list(
  logit_outcome = logit_outcome,
  logit_var = logit_var,
  doc = doc,
  pac_sqrt = pac_sqrt,
  doc_sqrt = doc_sqrt,
  age_norm = age_norm,
  dose_gap = dose_gap,
  id2 = id2,
  n = length(unique(id2)),
  nid = length(unique(id2))
))


# Same structure as the linear-dose model, but the drug-specific dose terms
# are the standardized Box-Cox doses (pac_box, doc_box)
box_model <- "
model{

#Likelihood
  for( i in 1:n)
    {
      logit_outcome[i]~dnorm(mu[i], 1/(logit_var[i]))
      mu[i]<-b[1]+b[2]*doc[i]+b[3]*pac_box[i]+b[4]*doc_box[i]+
        b[5]*age_norm[i]+b[6]*dose_gap[i]+u[id2[i]]
    }

for (j in 1:nid)
  {
    u[j]~dnorm(0, tau_u)
  }
#priors
for(j in 1:6) { b[j]~dnorm(0, 1E-6)}

tau_u <- pow(sigma, -2)
      sigma ~ dgamma(0.001, 0.001)


}
"

# Data list for the model; n and nid are both the number of distinct id2 values
box_dat <- with(dat, list(
  logit_outcome = logit_outcome,
  logit_var = logit_var,
  doc = doc,
  pac_box = pac_box,
  doc_box = doc_box,
  age_norm = age_norm,
  dose_gap = dose_gap,
  id2 = id2,
  n = length(unique(id2)),
  nid = length(unique(id2))
))


## Linear-dose model with an additional random slope

# Extends the linear-dose model with a term m[id2] * deliv_norm. The random
# slopes m have precision phi_u = delta^-2, and delta gets the same
# dgamma(0.001, 0.001) prior as sigma.
lin_model_slope <- "
model{

#Likelihood
  for( i in 1:n)
    {
      logit_outcome[i]~dnorm(mu[i], 1/(logit_var[i]))
      mu[i]<-b[1]+b[2]*doc[i]+b[3]*pac_dose[i]+b[4]*doc_dose[i]+
        b[5]*age_norm[i]+b[6]*dose_gap[i]+u[id2[i]] + m[id2[i]]*deliv_norm[i]
    }

for (j in 1:nid)
  {
    u[j]~dnorm(0, tau_u)
    m[j]~dnorm(0, phi_u)
  }
#priors
for(j in 1:6) { b[j]~dnorm(0, 1E-6)}

tau_u <- pow(sigma, -2)
sigma ~ dgamma(0.001, 0.001)

phi_u <- pow(delta, -2)
     delta ~ dgamma(0.001, 0.001)


}
"

# Data list for the model; deliv_norm is supplied for the random-slope term.
# n and nid are both the number of distinct id2 values.
lin_dat_slope <- with(dat, list(
  logit_outcome = logit_outcome,
  logit_var = logit_var,
  doc = doc,
  pac_dose = pac_dose,
  doc_dose = doc_dose,
  age_norm = age_norm,
  dose_gap = dose_gap,
  id2 = id2,
  deliv_norm = deliv_norm,
  n = length(unique(id2)),
  nid = length(unique(id2))
))


# Linear-dose model without age_norm and dose_gap: the mean keeps only the
# intercept, the docetaxel indicator, the two drug-specific dose slopes and
# the random intercept, so only b[1] to b[4] are given priors
lin_model_nocov <- "
model{

#Likelihood
  for( i in 1:n)
    {
      logit_outcome[i]~dnorm(mu[i], 1/(logit_var[i]))
      mu[i]<-b[1]+b[2]*doc[i]+b[3]*pac_dose[i]+b[4]*doc_dose[i]+
        u[id2[i]]
    }

for (j in 1:nid)
  {
    u[j]~dnorm(0, tau_u)
  }
#priors
for(j in 1:4) { b[j]~dnorm(0, 1E-6)}

tau_u <- pow(sigma, -2)
      sigma ~ dgamma(0.001, 0.001)


}
"

# Data list for the model; n and nid are both the number of distinct id2 values
lin_dat_nocov <- with(dat, list(
  logit_outcome = logit_outcome,
  logit_var = logit_var,
  doc = doc,
  pac_dose = pac_dose,
  doc_dose = doc_dose,
  id2 = id2,
  n = length(unique(id2)),
  nid = length(unique(id2))
))

# Linear-dose model without age_norm and dose_gap, with a random slope
# m[id2] * deliv_norm in addition to the random intercept.
#
# The prior loop covers b[1] to b[4] only. It previously ran to 6 although
# the mean uses four coefficients, so b[5] and b[6] were defined but never
# entered the likelihood and were sampled from their prior alone. This now
# matches the no-covariate model without a random slope.
lin_model_slope_nocov <-"
model{

#Likelihood
  for( i in 1:n)
    {
      logit_outcome[i]~dnorm(mu[i], 1/(logit_var[i]))
      mu[i]<-b[1]+b[2]*doc[i]+b[3]*pac_dose[i]+b[4]*doc_dose[i]+
        u[id2[i]] + m[id2[i]]*deliv_norm[i]
    }

for (j in 1:nid)
  {
    u[j]~dnorm(0, tau_u)
    m[j]~dnorm(0, phi_u)
  }
#priors
for(j in 1:4) { b[j]~dnorm(0, 1E-6)}

tau_u <- pow(sigma, -2)
sigma ~ dgamma(0.001, 0.001)

phi_u <- pow(delta, -2)
      delta ~ dgamma(0.001, 0.001)

}
"

# Data list for the model; deliv_norm is supplied for the random-slope term.
# n and nid are both the number of distinct id2 values.
lin_dat_slope_nocov <- list(logit_outcome = dat$logit_outcome,
                            logit_var = dat$logit_var,
                            doc = dat$doc, pac_dose = dat$pac_dose,
                            doc_dose = dat$doc_dose, id2 = dat$id2,
                            deliv_norm = dat$deliv_norm,
                            n = length(unique(dat$id2)),
                            nid = length(unique(dat$id2)))


### Sampling using CODA for parameter summaries

# Compile a separate model object for the DIC calculation further below.
# Two things differ from calling jags.model(textConnection(...)) inline:
#   * the text connection is closed again instead of being left open;
#   * the chains are burned in for as many iterations as the chains used for
#     the parameter summaries, so the DIC is not computed from draws taken
#     straight after compilation.
# Returns the compiled and burned-in model object.
compile_for_dic <- function(model, dat, chains = 4, burnin = 15000) {
  model_con <- textConnection(model)
  on.exit(close(model_con), add = TRUE)
  mod <- jags.model(file = model_con, data = dat, n.chains = chains)
  update(mod, burnin)
  mod
}

# For every model: four chains, 15000 burn-in iterations, then 500000
# iterations thinned by 50 for the parameter summaries (*_samps), plus a
# second compiled model (*_mod) for the DIC.
lin_samps <- fit_mod(model = lin_model, dat = lin_dat, chains = 4,
                     burnin = 15000, iterations = 500000,
                     thinning = 50, varnames = c("b", "tau_u", "u"))

lin_mod <- compile_for_dic(lin_model, lin_dat)

log_samps <- fit_mod(model = log_model, dat = log_dat, chains = 4,
                     burnin = 15000, iterations = 500000,
                     thinning = 50, varnames = c("b", "tau_u", "u"))

log_mod <- compile_for_dic(log_model, log_dat)

sqrt_samps <- fit_mod(model = sqrt_model, dat = sqrt_dat, chains = 4,
                      burnin = 15000, iterations = 500000,
                      thinning = 50, varnames = c("b", "tau_u", "u"))

sqrt_mod <- compile_for_dic(sqrt_model, sqrt_dat)

box_samps <- fit_mod(model = box_model, dat = box_dat, chains = 4,
                     burnin = 15000, iterations = 500000,
                     thinning = 50, varnames = c("b", "tau_u", "u"))

box_mod <- compile_for_dic(box_model, box_dat)

lin_samps_slope <- fit_mod(model = lin_model_slope, dat = lin_dat_slope,
                           chains = 4,
                           burnin = 15000, iterations = 500000,
                           thinning = 50,
                           varnames = c("b", "tau_u", "u", "m"))

lin_mod_slope <- compile_for_dic(lin_model_slope, lin_dat_slope)

lin_samps_nocov <- fit_mod(model = lin_model_nocov, dat = lin_dat_nocov,
                           chains = 4,
                           burnin = 15000, iterations = 500000,
                           thinning = 50, varnames = c("b", "tau_u", "u"))

lin_mod_nocov <- compile_for_dic(lin_model_nocov, lin_dat_nocov)

# The random slopes "m" are monitored here as well, as for the other
# random-slope model; they come last so existing column positions are kept.
lin_samps_slope_nocov <- fit_mod(model = lin_model_slope_nocov,
                                 dat = lin_dat_slope_nocov, chains = 4,
                                 burnin = 15000, iterations = 500000,
                                 thinning = 50,
                                 varnames = c("b", "tau_u", "u", "m"))

lin_mod_slope_nocov <- compile_for_dic(lin_model_slope_nocov,
                                       lin_dat_slope_nocov)

# DIC for every candidate model: 500000 iterations thinned by 50, with the
# "pD" penalty type
dic_lin             <- dic.samples(lin_mod, n.iter = 5e5, thin = 50, type = "pD")
dic_log             <- dic.samples(log_mod, n.iter = 5e5, thin = 50, type = "pD")
dic_sqrt            <- dic.samples(sqrt_mod, n.iter = 5e5, thin = 50, type = "pD")
dic_box             <- dic.samples(box_mod, n.iter = 5e5, thin = 50, type = "pD")
dic_lin_slope       <- dic.samples(lin_mod_slope, n.iter = 5e5, thin = 50,
                                   type = "pD")
dic_lin_nocov       <- dic.samples(lin_mod_nocov, n.iter = 5e5, thin = 50,
                                   type = "pD")
dic_lin_slope_nocov <- dic.samples(lin_mod_slope_nocov, n.iter = 5e5, thin = 50,
                                   type = "pD")

# Print the results for comparison across models
dic_lin
dic_log
dic_sqrt
dic_box
dic_lin_slope
dic_lin_nocov
dic_lin_slope_nocov


#' Summarize and plot diagnostics for the leading columns of a set of samples.
#'
#' Only the first `ncovars + 1` columns of `samples` are used.
#'
#' @param samples Posterior samples supporting `samples[, cols]` indexing.
#' @param ncovars Number of leading columns to keep, minus one.
#' @return An unnamed list holding, in order, the results of `MCMCsummary()`,
#'   `MCMCplot()`, `plot()`, `acfplot()` and `gelman.plot()`.
model_summary <- function(samples, ncovars){
  # Restrict to the leading columns once and reuse the subset everywhere
  leading <- samples[, 1:(ncovars + 1)]

  summary_tab <- MCMCsummary(leading)
  caterpillar <- MCMCplot(leading, sz_thick = 0)
  trace_dens  <- plot(leading)
  autocorr    <- acfplot(leading, lag.max = 100)
  gelman      <- gelman.plot(leading)

  list(summary_tab, caterpillar, trace_dens, autocorr, gelman)
}
# Summary table and diagnostic plots for the linear-dose model, restricted to
# the first seven columns of its samples
model_summary(samples = lin_samps, ncovars = 6)



# Centering and scaling constants used to standardize the delivered dose
# within each drug (0 = paclitaxel, 1 = docetaxel)
d0_bar <- attributes(pac_norm)$`scaled:center`
s0 <- attributes(pac_norm)$`scaled:scale`
d1_bar <- attributes(doc_norm)$`scaled:center`
s1 <- attributes(doc_norm)$`scaled:scale`


# Stack the draws of all chains into one matrix, however many chains were run
mm <- do.call(rbind, lin_samps)
n_draws <- nrow(mm)

# Look at cumulative paclitaxel doses between 400 and 1600
possible_doses <- seq(400, 1600, by = 5)
n <- length(possible_doses)

# Coefficients needed for the relationship, selected by name rather than by
# column position so a change in the monitored variables cannot silently pick
# up the wrong columns: b[2] docetaxel indicator, b[3] paclitaxel dose slope,
# b[4] docetaxel dose slope
vars <- mm[, c("b[2]", "b[3]", "b[4]")]

# One row per (dose, draw) pair: each dose is repeated once per posterior
# draw, and the full set of draws is repeated once per dose. The repeat count
# comes from the data instead of a hard-coded 40000, which was only correct
# for 4 chains x 500000 iterations / thinning of 50.
vals <- data.frame(d0 = rep(possible_doses, each = n_draws),
                   B2 = rep(vars[, 1], times = n),
                   B3 = rep(vars[, 2], times = n),
                   B4 = rep(vars[, 3], times = n),
                   d0_bar = d0_bar,
                   s0 = s0,
                   d1_bar = d1_bar,
                   s1 = s1)

# Docetaxel dose d1 with the same linear predictor as paclitaxel dose d0,
# i.e. the solution of
#   B3 * (d0 - d0_bar) / s0 = B2 + B4 * (d1 - d1_bar) / s1
vals <- vals %>%
  mutate(d1 = (-B2 + (B4 * d1_bar) / s1 + B3 * (d0 - d0_bar) / s0) / (B4 / s1))

# At each paclitaxel dose, summarize the equivalent dose across all draws;
# the lower bound is truncated at 0
vals2 <- vals %>% 
  group_by(d0) %>% 
  dplyr::summarize(lower = quantile(d1, probs = .025), 
                   upper = quantile(d1, probs = 0.975),
                   median = median(d1)) %>%
  mutate(lower = ifelse(lower < 0, 0, lower))

# Named palette shared by the color and fill scales; the names are the legend
# labels used in the aes() mappings below
colors <- c("Median" = "#81144e", "Credible interval" = "#81144e",
            "Median sqrt" = "#FAA31B", "Credible interval sqrt" = "#FAA31B")

# Plot the equivalence relationship with its 95% credible interval.
# This needs ggplot2 on the search path; it is attached explicitly in the
# library block at the top of the file. Layers inherit the data and the
# x mapping from ggplot() instead of repeating them.
ggplot(vals2, aes(x = d0)) +
  # Posterior median
  geom_line(aes(y = median, color = "Median"), size = 2) +
  # Shaded credible band
  geom_ribbon(aes(ymin = lower, ymax = upper, fill = "Credible interval"), 
              alpha = .3) +
  # Dotted outline of the band, kept out of the legend
  geom_ribbon(aes(ymin = lower, ymax = upper), 
              color = colors["Credible interval"], 
              show.legend = FALSE, fill = NA, size = 1, linetype = 3) +
  xlab(TeX("Delivered paclitaxel ($mg/m^2$)")) +
  ylab(TeX("Delivered docetaxel ($mg/m^2$)")) +
  #ggtitle("Dose equivalence of docetaxel and paclitaxel") +
  theme_light() +
  theme(text = element_text(size = 15),
        axis.text.x = element_text(size = 10),
        legend.position = c(.05, .9),
        legend.justification = c("left", "top"),
        legend.box.just = "left",
        legend.margin = margin(0, 0, 0, 0),
        legend.spacing = unit(0, 'cm'),
        legend.background = element_rect(fill = "transparent")
  ) +
  labs(fill = NULL,
       color = NULL) +
  scale_color_manual(values = colors) +
  scale_fill_manual(values = colors)
